{
 "cells": [
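  {
   "cell_type": "markdown",
   "id": "684fddbc",
   "metadata": {},
   "source": [
    "# Train a Model\n",
    "\n",
    "This tutorial shows how to train a knowledge graph embedding model with AmpliGraph on the FB15k-237 benchmark dataset and how to evaluate the trained model on the test set."
   ]
  },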
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f7d979c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Add the directory two levels up to the module search path\n",
    "sys.path.append('../..')\n",
    "\n",
    "# These environment variables are set before TensorFlow is imported below\n",
    "os.environ.update({\n",
    "    'CUDA_VISIBLE_DEVICES': '0',\n",
    "    'TF_ENABLE_ONEDNN_OPTS': '0',\n",
    "    'TF_CPP_MIN_LOG_LEVEL': '3',\n",
    "})\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "# Set the TensorFlow Python logger to the ERROR level\n",
    "tf.get_logger().setLevel('ERROR')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dafa4c11",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ampligraph\n",
    "\n",
    "# Loaders for the benchmark datasets live in the ampligraph.datasets module\n",
    "from ampligraph.datasets import load_fb15k_237\n",
    "\n",
    "# Load FB15k-237; the cells below read its 'train' and 'test' entries\n",
    "dataset = load_fb15k_237()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "87ed1243",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metal device set to: Apple M1 Pro\n",
      "\n",
      "systemMemory: 32.00 GB\n",
      "maxCacheSize: 10.67 GB\n",
      "\n",
      "Epoch 1/10\n",
      "29/29 [==============================] - 3s 94ms/step - loss: 17685.6055\n",
      "Epoch 2/10\n",
      "29/29 [==============================] - 0s 14ms/step - loss: 17684.5527\n",
      "Epoch 3/10\n",
      "29/29 [==============================] - 1s 41ms/step - loss: 17684.2188\n",
      "Epoch 4/10\n",
      "29/29 [==============================] - 0s 13ms/step - loss: 17684.0547\n",
      "Epoch 5/10\n",
      "29/29 [==============================] - 1s 22ms/step - loss: 17683.9531\n",
      "Epoch 6/10\n",
      "29/29 [==============================] - 1s 35ms/step - loss: 17683.8867\n",
      "Epoch 7/10\n",
      "29/29 [==============================] - 1s 42ms/step - loss: 17683.8398\n",
      "Epoch 8/10\n",
      "29/29 [==============================] - 2s 76ms/step - loss: 17683.8047\n",
      "Epoch 9/10\n",
      "29/29 [==============================] - 0s 13ms/step - loss: 17683.7773\n",
      "Epoch 10/10\n",
      "29/29 [==============================] - 0s 10ms/step - loss: 17683.7539\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x29d330820>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Import the KGE model\n",
    "from ampligraph.latent_features import ScoringBasedEmbeddingModel\n",
    "\n",
    "# Create the model with the 'Random' scoring function (note: this is not TransE).\n",
    "# To train a TransE model instead, pass scoring_type='TransE'\n",
    "# (see the AmpliGraph documentation for the supported scoring types).\n",
    "model = ScoringBasedEmbeddingModel(\n",
    "    eta=5,\n",
    "    k=300,\n",
    "    scoring_type='Random',\n",
    ")\n",
    "\n",
    "# Compile the model with the 'adam' optimizer and the 'multiclass_nll' loss\n",
    "model.compile(optimizer='adam', loss='multiclass_nll')\n",
    "\n",
    "# Fit the model on the training split: 10 epochs, batch size 10000\n",
    "model.fit(dataset['train'], batch_size=10000, epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e5b4dc3b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "206/206 [==============================] - 5s 24ms/step\n",
      "MR: 6628.270672277131\n",
      "MRR: 0.0008434279835981804\n",
      "hits@1: 0.0\n",
      "hits@10: 0.0\n"
     ]
    }
   ],
   "source": [
    "# Metric helpers that operate on the ranks returned by model.evaluate\n",
    "from ampligraph.evaluation.metrics import mrr_score, hits_at_n_score, mr_score\n",
    "\n",
    "# Compute ranks for the test split (evaluation batch size 100, corrupt_side='s,o')\n",
    "ranks = model.evaluate(\n",
    "    dataset['test'],\n",
    "    batch_size=100,\n",
    "    corrupt_side='s,o',\n",
    ")\n",
    "\n",
    "# Report MR, MRR and hits@N (N = 1 and 10) computed from the ranks\n",
    "print('MR:', mr_score(ranks))\n",
    "print('MRR:', mrr_score(ranks))\n",
    "print('hits@1:', hits_at_n_score(ranks, 1))\n",
    "print('hits@10:', hits_at_n_score(ranks, 10))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "2e69f3670cdad0193847aaa0b77be56c05c951fcbdd384ff882dde0464f4de76"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
